{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Decision Tree Classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Decision trees are a popular class of algorithms in machine learning. In this notebook we build a simple decision tree classifier with `scikit-learn` and show that it can be executed homomorphically using Concrete Numpy.\n",
    "\n",
    "State-of-the-art classifiers are generally more complex than a single decision tree. The goal here is to demonstrate FHE decision trees, so the results may not compete with the best available models.\n",
    "\n",
    "Converting a tree working over quantized data to its FHE equivalent takes only a few lines of code thanks to Concrete Numpy.\n",
    "\n",
    "Let's dive in!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Use Case"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The use case is a spam classification task from OpenML, available here: https://www.openml.org/d/44\n",
    "\n",
    "The dataset provides 57 pre-extracted features (such as word frequencies) and a class label, `0` for a normal e-mail and `1` for spam, for each of its 4601 samples.\n",
    "\n",
    "Let's first get the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4601, 57)\n",
      "(4601,)\n",
      "Number of features: 57\n"
     ]
    }
   ],
   "source": [
    "import numpy\n",
    "from sklearn.datasets import fetch_openml\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "features, classes = fetch_openml(data_id=44, as_frame=False, cache=True, return_X_y=True)\n",
    "classes = classes.astype(numpy.int64)\n",
    "\n",
    "print(features.shape)\n",
    "print(classes.shape)\n",
    "\n",
    "num_features = features.shape[1]\n",
    "print(f\"Number of features: {num_features}\")\n",
    "\n",
    "x_train, x_test, y_train, y_test = train_test_split(\n",
    "    features,\n",
    "    classes,\n",
    "    test_size=0.15,\n",
    "    random_state=42,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first train a decision tree on the dataset as is and see what performance we can get."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Depth: 29\n",
      "Mean accuracy: 0.91027496382055\n",
      "Number of test samples: 691\n",
      "Number of spams in test samples: 304\n",
      "True Negative (legit mail well classified) rate: 0.9328165374677002\n",
      "False Positive (legit mail classified as spam) rate: 0.06718346253229975\n",
      "False Negative (spam mail classified as legit) rate: 0.11842105263157894\n",
      "True Positive (spam well classified) rate: 0.881578947368421\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import confusion_matrix\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "clear_clf = DecisionTreeClassifier()\n",
    "clear_clf = clear_clf.fit(x_train, y_train)\n",
    "\n",
    "print(f\"Depth: {clear_clf.get_depth()}\")\n",
    "\n",
    "preds = clear_clf.predict(x_test)\n",
    "\n",
    "mean_accuracy = numpy.mean(preds == y_test)\n",
    "print(f\"Mean accuracy: {mean_accuracy}\")\n",
    "\n",
    "true_negative, false_positive, false_negative, true_positive = confusion_matrix(\n",
    "    y_test, preds, normalize=\"true\"\n",
    ").ravel()\n",
    "\n",
    "num_samples = len(y_test)\n",
    "num_spam = sum(y_test)\n",
    "\n",
    "print(f\"Number of test samples: {num_samples}\")\n",
    "print(f\"Number of spams in test samples: {num_spam}\")\n",
    "\n",
    "print(f\"True Negative (legit mail well classified) rate: {true_negative}\")\n",
    "print(f\"False Positive (legit mail classified as spam) rate: {false_positive}\")\n",
    "print(f\"False Negative (spam mail classified as legit) rate: {false_negative}\")\n",
    "print(f\"True Positive (spam well classified) rate: {true_positive}\")\n"
   ]
  },
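  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a side note, the rates printed above come from `confusion_matrix(..., normalize=\"true\")`, which divides each row of the matrix by the number of samples in that true class. The minimal sketch below illustrates this on made-up labels; the `toy_*` names and the data are assumptions introduced only for this illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "# Made-up labels: 4 legit mails (class 0) and 2 spams (class 1)\n",
    "toy_y_true = numpy.array([0, 0, 0, 0, 1, 1])\n",
    "toy_y_pred = numpy.array([0, 0, 0, 1, 1, 0])\n",
    "\n",
    "# For binary labels, ravel() yields the order: TN, FP, FN, TP\n",
    "toy_tn, toy_fp, toy_fn, toy_tp = confusion_matrix(toy_y_true, toy_y_pred).ravel()\n",
    "toy_rates = confusion_matrix(toy_y_true, toy_y_pred, normalize=\"true\").ravel()\n",
    "toy_tn_rate, toy_fp_rate, toy_fn_rate, toy_tp_rate = toy_rates\n",
    "\n",
    "# Each rate is a raw count divided by the size of its true class\n",
    "assert numpy.isclose(toy_fp_rate, toy_fp / (toy_tn + toy_fp))\n",
    "assert numpy.isclose(toy_fn_rate, toy_fn / (toy_fn + toy_tp))\n",
    "\n",
    "print(f\"Counts (TN, FP, FN, TP): {toy_tn}, {toy_fp}, {toy_fn}, {toy_tp}\")\n",
    "print(f\"Rates  (TN, FP, FN, TP): {toy_rates}\")"
   ]
  },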
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now quantize the features so that the tree can be trained directly on quantized data. This makes the trained tree FHE-friendly by default, and it also lets us compare the two trees.\n",
    "\n",
    "We use 6 bits for each feature, quantized individually, because the precision of Concrete Numpy's PBSes (programmable bootstraps) is better at 6 bits. The intent is to compute the quantization parameters over the training set. Note that the cell below builds a separate `QuantizedArray` for each split, so the test features are quantized with parameters derived from the test data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0  0  6  0  3  5  0  0  0  2  0 19  0  0  0  0  0  0  3  0  0  0  0  0\n",
      "  4  4  0  7  3  0  0  0  2  0  0  4  0  0  0  0  0  0  0  0  0  0  0  0\n",
      "  0  1  0  0  0  0  0  0  1]\n",
      "[ 0  0  0  0  6  0  0  0  0  0  0 10  0  0  0  0  0  0  4  0  7  0  0  0\n",
      "  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0\n",
      "  0  0  0  0  0  0  0  0  0]\n"
     ]
    }
   ],
   "source": [
    "from concrete.quantization import QuantizedArray\n",
    "\n",
    "# Quantize each feature of the training and test samples to 6 bits\n",
    "q_x_train = numpy.zeros_like(x_train, dtype=numpy.int64)\n",
    "q_x_test = numpy.zeros_like(x_test, dtype=numpy.int64)\n",
    "for feature_idx in range(num_features):\n",
    "    q_x_train[:, feature_idx] = QuantizedArray(6, x_train[:, feature_idx]).qvalues\n",
    "    q_x_test[:, feature_idx] = QuantizedArray(6, x_test[:, feature_idx]).qvalues\n",
    "\n",
    "print(q_x_train[0])\n",
    "print(q_x_test[-1])\n"
   ]
  },
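  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For illustration, the sketch below shows one way to compute quantization parameters on the training values of a feature and reuse them for other data. It is a plain NumPy uniform (min/max) quantizer, not the `QuantizedArray` implementation; the helper names and the sample values are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "\n",
    "\n",
    "def fit_uniform_quantizer(values, n_bits):\n",
    "    # Hypothetical helper: derive (scale, minimum) from the range of `values`\n",
    "    v_min, v_max = float(values.min()), float(values.max())\n",
    "    n_levels = 2**n_bits - 1\n",
    "    # Guard against constant features, whose range is zero\n",
    "    scale = (v_max - v_min) / n_levels if v_max > v_min else 1.0\n",
    "    return scale, v_min\n",
    "\n",
    "\n",
    "def apply_uniform_quantizer(values, scale, v_min, n_bits):\n",
    "    # Hypothetical helper: map `values` to integers in [0, 2**n_bits - 1]\n",
    "    q_values = numpy.rint((values - v_min) / scale)\n",
    "    # Values outside the calibration range are clipped to the nearest bound\n",
    "    return numpy.clip(q_values, 0, 2**n_bits - 1).astype(numpy.int64)\n",
    "\n",
    "\n",
    "# Made-up values for a single feature\n",
    "toy_train_column = numpy.array([0.0, 0.5, 1.0, 2.0])\n",
    "toy_test_column = numpy.array([0.25, 3.0])\n",
    "\n",
    "# Parameters are computed once, on the training values only\n",
    "toy_scale, toy_min = fit_uniform_quantizer(toy_train_column, n_bits=6)\n",
    "\n",
    "toy_q_train = apply_uniform_quantizer(toy_train_column, toy_scale, toy_min, n_bits=6)\n",
    "toy_q_test = apply_uniform_quantizer(toy_test_column, toy_scale, toy_min, n_bits=6)\n",
    "\n",
    "# All quantized values fit in 6 bits\n",
    "assert toy_q_train.min() >= 0 and toy_q_train.max() <= 63\n",
    "assert toy_q_test.min() >= 0 and toy_q_test.max() <= 63\n",
    "\n",
    "print(toy_q_train)\n",
    "print(toy_q_test)"
   ]
  },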
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So far so good. We can now train a `DecisionTreeClassifier` on the quantized dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Depth: 7\n",
      "Mean accuracy: 0.8813314037626628\n",
      "Number of test samples: 691\n",
      "Number of spams in test samples: 304\n",
      "True Negative (legit mail well classified) rate: 0.9276485788113695\n",
      "False Positive (legit mail classified as spam) rate: 0.07235142118863049\n",
      "False Negative (spam mail classified as legit) rate: 0.17763157894736842\n",
      "True Positive (spam well classified) rate: 0.8223684210526315\n"
     ]
    }
   ],
   "source": [
    "# We limit the depth to have reasonable FHE runtimes, but deep trees can still compile properly!\n",
    "clf = DecisionTreeClassifier(max_depth=7)\n",
    "clf = clf.fit(q_x_train, y_train)\n",
    "\n",
    "print(f\"Depth: {clf.get_depth()}\")\n",
    "\n",
    "preds = clf.predict(q_x_test)\n",
    "\n",
    "mean_accuracy = numpy.mean(preds == y_test)\n",
    "print(f\"Mean accuracy: {mean_accuracy}\")\n",
    "\n",
    "true_negative, false_positive, false_negative, true_positive = confusion_matrix(\n",
    "    y_test, preds, normalize=\"true\"\n",
    ").ravel()\n",
    "\n",
    "num_samples = len(y_test)\n",
    "num_spam = sum(y_test)\n",
    "\n",
    "print(f\"Number of test samples: {num_samples}\")\n",
    "print(f\"Number of spams in test samples: {num_spam}\")\n",
    "\n",
    "print(f\"True Negative (legit mail well classified) rate: {true_negative}\")\n",
    "print(f\"False Positive (legit mail classified as spam) rate: {false_positive}\")\n",
    "print(f\"False Negative (spam mail classified as legit) rate: {false_negative}\")\n",
    "print(f\"True Positive (spam well classified) rate: {true_positive}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This simple classifier achieves about a 7% false positive (legit mail classified as spam) rate and about an 18% false negative (spam mail classified as legit) rate. In a more common setting, not shown in this tutorial, we would use gradient boosting to combine several small classifiers into a single, more effective one.\n",
    "\n",
    "The accuracy (about 88%) is relatively close to that of the tree trained in the clear (about 91%), despite the quantization (needed for FHE compatibility) and the smaller depth (chosen for faster FHE computations). The main difference is a higher false negative rate (spam mail classified as legit).\n",
    "\n",
    "The point here is not to beat state-of-the-art spam detection methods, but rather to show that a given tree classifier can be run homomorphically."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Homomorphic Trees"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we can do that, we need to convert the tree to a form that is easy to run homomorphically.\n",
    "\n",
    "The Hummingbird paper from Microsoft (https://scnakandala.github.io/papers/TR_2020_Hummingbird.pdf and https://github.com/microsoft/hummingbird) gives a method to convert tree evaluation to tensor operations, which are supported in Concrete Numpy.\n",
    "\n",
    "The next few cells implement the functions necessary for the conversion. They are deliberately left unoptimized so that they remain readable.\n"
   ]
  },
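  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the tensor formulation more concrete, here is a small sketch on a hand-built tree with two internal nodes and three leaves. The tree, the `toy_*` names and the sample values are assumptions made for illustration; the five arrays play the same roles as the tensors A to E built by the functions below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy\n",
    "\n",
    "# Hypothetical tree used only for this illustration:\n",
    "#   node 0: feature 0 <= 2 ? leaf 0 (class 0) : node 1\n",
    "#   node 1: feature 1 <= 5 ? leaf 1 (class 1) : leaf 2 (class 0)\n",
    "\n",
    "# A (features x internal nodes): which feature each internal node tests\n",
    "toy_a = numpy.array([[1, 0], [0, 1]])\n",
    "# B (internal nodes): threshold of each internal node\n",
    "toy_b = numpy.array([2, 5])\n",
    "# C (internal nodes x leaves): 1 if the leaf is in the left subtree, -1 if right, 0 otherwise\n",
    "toy_c = numpy.array([[1, -1, -1], [0, 1, -1]])\n",
    "# D (leaves): number of ancestors for which the leaf lies in the left subtree\n",
    "toy_d = numpy.array([1, 1, 0])\n",
    "# E (leaves x classes): one-hot class of each leaf\n",
    "toy_e = numpy.array([[1, 0], [0, 1], [1, 0]])\n",
    "\n",
    "\n",
    "def toy_tree_predict(inputs):\n",
    "    t = (inputs @ toy_a) <= toy_b  # evaluate every node condition\n",
    "    t = (t @ toy_c) == toy_d  # select the leaf whose path conditions all match\n",
    "    return t @ toy_e  # one-hot class of the selected leaf\n",
    "\n",
    "\n",
    "def toy_tree_reference(sample):\n",
    "    # Direct if/else traversal of the same tree\n",
    "    if sample[0] <= 2:\n",
    "        return 0\n",
    "    return 1 if sample[1] <= 5 else 0\n",
    "\n",
    "\n",
    "toy_samples = numpy.array([[1, 9], [3, 4], [3, 9], [1, 4]])\n",
    "toy_tensor_classes = numpy.argmax(toy_tree_predict(toy_samples), axis=1)\n",
    "toy_reference_classes = numpy.array([toy_tree_reference(s) for s in toy_samples])\n",
    "\n",
    "# Both evaluations should agree on every sample\n",
    "assert numpy.array_equal(toy_tensor_classes, toy_reference_classes)\n",
    "print(toy_tensor_classes)"
   ]
  },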
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First an sklearn import we need\n",
    "from sklearn.tree import _tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_hummingbird_tensor_a(tree_, features, internal_nodes):\n",
    "    \"\"\"Create Hummingbird tensor A.\"\"\"\n",
    "    a = numpy.zeros((len(features), len(internal_nodes)), dtype=numpy.int64)\n",
    "    for i in range(a.shape[0]):\n",
    "        for j in range(a.shape[1]):\n",
    "            a[i, j] = tree_.feature[internal_nodes[j]] == features[i]\n",
    "\n",
    "    return a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_hummingbird_tensor_b(tree_, internal_nodes, is_integer_tree=False):\n",
    "    \"\"\"Create Hummingbird tensor B.\"\"\"\n",
    "    b = numpy.array([tree_.threshold[int_node] for int_node in internal_nodes])\n",
    "\n",
    "    # With integer inputs, x <= threshold is equivalent to x <= floor(threshold)\n",
    "    return numpy.floor(b).astype(numpy.int64) if is_integer_tree else b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_subtree_nodes_set_per_node(\n",
    "    all_nodes, leaf_nodes, is_left_child_of: dict, is_right_child_of: dict\n",
    "):\n",
    "    \"\"\"Create subtrees nodes set for each node in the tree.\"\"\"\n",
    "    left_subtree_nodes_per_node = {node: set() for node in all_nodes}\n",
    "    right_subtree_nodes_per_node = {node: set() for node in all_nodes}\n",
    "\n",
    "    current_nodes = {node: None for node in leaf_nodes}\n",
    "    while current_nodes:\n",
    "        next_nodes = {}\n",
    "        for node in current_nodes:\n",
    "            parent_as_left_child = is_left_child_of.get(node, None)\n",
    "            if parent_as_left_child is not None:\n",
    "                left_subtree = left_subtree_nodes_per_node[parent_as_left_child]\n",
    "                left_subtree.add(node)\n",
    "                left_subtree.update(left_subtree_nodes_per_node[node])\n",
    "                left_subtree.update(right_subtree_nodes_per_node[node])\n",
    "                next_nodes.update({parent_as_left_child: None})\n",
    "\n",
    "            parent_as_right_child = is_right_child_of.get(node, None)\n",
    "            if parent_as_right_child is not None:\n",
    "                right_subtree = right_subtree_nodes_per_node[parent_as_right_child]\n",
    "                right_subtree.add(node)\n",
    "                right_subtree.update(left_subtree_nodes_per_node[node])\n",
    "                right_subtree.update(right_subtree_nodes_per_node[node])\n",
    "                next_nodes.update({parent_as_right_child: None})\n",
    "\n",
    "        current_nodes = next_nodes\n",
    "\n",
    "    return left_subtree_nodes_per_node, right_subtree_nodes_per_node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_hummingbird_tensor_c(\n",
    "    all_nodes, internal_nodes, leaf_nodes, is_left_child_of: dict, is_right_child_of: dict\n",
    "):\n",
    "    \"\"\"Create Hummingbird tensor C.\"\"\"\n",
    "    left_subtree_nodes_per_node, right_subtree_nodes_per_node = create_subtree_nodes_set_per_node(\n",
    "        all_nodes, leaf_nodes, is_left_child_of, is_right_child_of\n",
    "    )\n",
    "\n",
    "    c = numpy.zeros((len(internal_nodes), len(leaf_nodes)), dtype=numpy.int64)\n",
    "\n",
    "    for i in range(c.shape[0]):\n",
    "        for j in range(c.shape[1]):\n",
    "            if leaf_nodes[j] in right_subtree_nodes_per_node[internal_nodes[i]]:\n",
    "                c[i, j] = -1\n",
    "            elif leaf_nodes[j] in left_subtree_nodes_per_node[internal_nodes[i]]:\n",
    "                c[i, j] = 1\n",
    "\n",
    "    return c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_hummingbird_tensor_d(leaf_nodes, is_left_child_of, is_right_child_of):\n",
    "    \"\"\"Create Hummingbird tensor D.\"\"\"\n",
    "    d = numpy.zeros((len(leaf_nodes)), dtype=numpy.int64)\n",
    "    for k in range(d.shape[0]):\n",
    "        current_node = leaf_nodes[k]\n",
    "        num_left_children = 0\n",
    "        while True:\n",
    "            if (parent_as_left_child := is_left_child_of.get(current_node, None)) is not None:\n",
    "                num_left_children += 1\n",
    "                current_node = parent_as_left_child\n",
    "            elif (parent_as_right_child := is_right_child_of.get(current_node, None)) is not None:\n",
    "                current_node = parent_as_right_child\n",
    "            else:\n",
    "                break\n",
    "        d[k] = num_left_children\n",
    "\n",
    "    return d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_hummingbird_tensor_e(tree_, leaf_nodes, classes):\n",
    "    \"\"\"Create Hummingbird tensor E.\"\"\"\n",
    "    e = numpy.zeros((len(leaf_nodes), len(classes)), dtype=numpy.int64)\n",
    "    for i in range(e.shape[0]):\n",
    "        leaf_node = leaf_nodes[i]\n",
    "        assert tree_.feature[leaf_node] == _tree.TREE_UNDEFINED  # Sanity check\n",
    "        # The per-class values of a leaf do not depend on j, so compute them once\n",
    "        if tree_.n_outputs == 1:\n",
    "            value = tree_.value[leaf_node][0]\n",
    "        else:\n",
    "            value = tree_.value[leaf_node].T[0]\n",
    "        class_name = numpy.argmax(value)\n",
    "        for j in range(e.shape[1]):\n",
    "            e[i, j] = class_name == j\n",
    "\n",
    "    return e"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tree_to_numpy(tree, num_features, classes):\n",
    "    \"\"\"Convert an sklearn tree to its Hummingbird tensor equivalent.\"\"\"\n",
    "    tree_ = tree.tree_\n",
    "\n",
    "    number_of_nodes = tree_.node_count\n",
    "    all_nodes = list(range(number_of_nodes))\n",
    "    internal_nodes = [\n",
    "        node_idx\n",
    "        for node_idx, feature in enumerate(tree_.feature)\n",
    "        if feature != _tree.TREE_UNDEFINED\n",
    "    ]\n",
    "    leaf_nodes = [\n",
    "        node_idx\n",
    "        for node_idx, feature in enumerate(tree_.feature)\n",
    "        if feature == _tree.TREE_UNDEFINED\n",
    "    ]\n",
    "\n",
    "    features = list(range(num_features))\n",
    "\n",
    "    a = create_hummingbird_tensor_a(tree_, features, internal_nodes)\n",
    "\n",
    "    b = create_hummingbird_tensor_b(tree_, internal_nodes, is_integer_tree=True)\n",
    "\n",
    "    # Leaves have no children: sklearn marks them with _tree.TREE_LEAF in children_left/right\n",
    "    is_left_child_of = {\n",
    "        left_child: parent\n",
    "        for parent, left_child in enumerate(tree_.children_left)\n",
    "        if left_child != _tree.TREE_LEAF\n",
    "    }\n",
    "    is_right_child_of = {\n",
    "        right_child: parent\n",
    "        for parent, right_child in enumerate(tree_.children_right)\n",
    "        if right_child != _tree.TREE_LEAF\n",
    "    }\n",
    "\n",
    "    c = create_hummingbird_tensor_c(\n",
    "        all_nodes, internal_nodes, leaf_nodes, is_left_child_of, is_right_child_of\n",
    "    )\n",
    "\n",
    "    d = create_hummingbird_tensor_d(leaf_nodes, is_left_child_of, is_right_child_of)\n",
    "\n",
    "    e = create_hummingbird_tensor_e(tree_, leaf_nodes, classes)\n",
    "\n",
    "    def tree_predict(inputs):\n",
    "        # Select, for each internal node, the feature it tests\n",
    "        t = inputs @ a\n",
    "        # Evaluate every node condition (sklearn goes left when feature <= threshold)\n",
    "        t = t <= b\n",
    "        # Per leaf: true conditions at left-ancestors minus true conditions at right-ancestors\n",
    "        t = t @ c\n",
    "        # A leaf is reached only when that score equals its number of left-ancestors\n",
    "        t = t == d\n",
    "        # Map the reached leaf to its one-hot class vector\n",
    "        r = t @ e\n",
    "        return r\n",
    "\n",
    "    return tree_predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We can finally convert our tree!\n",
    "tree_predict = tree_to_numpy(clf, num_features, classes=[0, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Results are identical: True\n"
     ]
    }
   ],
   "source": [
    "# Let's see if it works as expected\n",
    "tensor_predictions = tree_predict(q_x_test)\n",
    "tensor_predictions = numpy.argmax(tensor_predictions, axis=1)\n",
    "\n",
    "tree_predictions = clf.predict(q_x_test)\n",
    "\n",
    "print(f\"Results are identical: {numpy.array_equal(tensor_predictions, tree_predictions)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now have a tensor equivalent of our `DecisionTreeClassifier`, pretty neat isn't it?\n",
    "\n",
    "The last step is compiling the tensor equivalent to FHE using Concrete Numpy, and it's nearly as easy as 1, 2, 3.\n",
    "\n",
    "We use the quantized training data as the inputset to calibrate the circuit during compilation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import concrete.numpy as hnp\n",
    "\n",
    "compiler = hnp.NPFHECompiler(tree_predict, {\"inputs\": \"encrypted\"})\n",
    "fhe_tree = compiler.compile_on_inputset((sample for sample in q_x_train))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now we can start running the tree homomorphically!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [05:01<00:00, 30.17s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Same predictions of FHE compared to clear: 10/10 (1.0)\n",
      "FHE evaluation #1 took 30.765692999993917 s\n",
      "FHE evaluation #2 took 30.604038099998434 s\n",
      "FHE evaluation #3 took 30.70741419999831 s\n",
      "FHE evaluation #4 took 30.64609560000099 s\n",
      "FHE evaluation #5 took 29.945520399996894 s\n",
      "FHE evaluation #6 took 30.155333900002006 s\n",
      "FHE evaluation #7 took 29.776400299997476 s\n",
      "FHE evaluation #8 took 30.12118709999777 s\n",
      "FHE evaluation #9 took 29.526597299998684 s\n",
      "FHE evaluation #10 took 29.392055899996194 s\n",
      "Mean FHE evaluation time: 30.16403357999807\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "from time import perf_counter\n",
    "\n",
    "num_runs = 10\n",
    "fhe_preds = []\n",
    "clear_preds = []\n",
    "fhe_eval_times = []\n",
    "for i in tqdm(range(num_runs)):\n",
    "    start = perf_counter()\n",
    "    fhe_pred = fhe_tree.run(q_x_test[i].astype(numpy.uint8))\n",
    "    stop = perf_counter()\n",
    "    fhe_eval_times.append(stop - start)\n",
    "    fhe_pred = numpy.argmax(fhe_pred)\n",
    "    fhe_preds.append(fhe_pred)\n",
    "    clear_pred = clf.predict(numpy.expand_dims(q_x_test[i], axis=0))\n",
    "    clear_pred = clear_pred[0]\n",
    "    clear_preds.append(clear_pred)\n",
    "\n",
    "fhe_preds = numpy.array(fhe_preds)\n",
    "clear_preds = numpy.array(clear_preds)\n",
    "\n",
    "same_preds = fhe_preds == clear_preds\n",
    "n_same_preds = sum(same_preds)\n",
    "print(\n",
    "    f\"Same predictions of FHE compared to clear: {n_same_preds}/{num_runs} \"\n",
    "    f\"({numpy.mean(same_preds)})\"\n",
    ")\n",
    "for idx, eval_time in enumerate(fhe_eval_times, 1):\n",
    "    print(f\"FHE evaluation #{idx} took {eval_time} s\")\n",
    "\n",
    "print(f\"Mean FHE evaluation time: {numpy.mean(fhe_eval_times)}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook we showed how to quantize a dataset to train a tree directly on integer data so that it is FHE-friendly. We saw that despite quantization and its smaller depth, the quantized tree's classification performance was close to that of a tree trained on the original real-valued dataset.\n",
    "\n",
    "We then used the Hummingbird paper's algorithm to transform tree evaluation into a few tensor operations, which Concrete Numpy can compile to an FHE circuit.\n",
    "\n",
    "Finally, we ran the compiled circuit on a few samples (FHE inference takes about 30 seconds per sample here) to show that the clear and FHE computations give the same predictions."
   ]
  }
 ],
 "metadata": {
  "execution": {
   "timeout": 10800
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
